import matplotlib
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
import matplotlib.ticker as ticker
import numpy as np
import pandas as pd
import os
import sys
import pickle
from joblib import dump, load
from math import sqrt
import yaml
import statsmodels.api as sm 
import statsmodels.formula.api as smf
from scipy import stats
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestRegressor
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import RandomizedSearchCV
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.linear_model import LinearRegression as reg
import pdb
from pprint import pprint
from statistics import mean
import random
random.seed(50)
import math
import argparse

sys.path.insert(1,'/REDACTED/fairness/code/utilities/')
import UncertaintySimulation as unc
from costCalculatorV2 import *

sys.path.insert(1,'/REDACTED/fairness/code/config')
from configureColors import *

# matplotlib.colors is used below (to_rgba and the named-color registry).
# Import it explicitly instead of relying on it leaking in through
# `from configureColors import *`.
import matplotlib.colors as mcolors

# set plot defaults: project style sheet, plus the bundled cmunrm.ttf font
# registered (first in the font list) under the family name 'latex'.
plt.style.use('/REDACTED/fairness/code/config/fairness.mplstyle')
fe = fm.FontEntry(
    fname='/REDACTED/fairness/code/config/fonts/cmunrm.ttf',
    name='latex')
fm.fontManager.ttflist.insert(0, fe)
matplotlib.rcParams['font.family'] = fe.name

# Each hex code in stanford_palette is registered as a named matplotlib color
# under the matching entry of color_codes, so plots can use e.g.
# color='cardinal red'. Note this overrides matplotlib's built-in entries that
# share a name (e.g. 'purple', 'orange', 'black').
stanford_palette = ['#f0f4f5', '#d0d8da', '#aabeC6', '#009abb', '#007c92', '#09425a',
                   '#c7d1c6', '#80982a', '#556222', '#b96d12', '#53284f', '#5e3032',
                   '#8c1515', '#dddddd', '#cccccc', '#999999', '#666666', '#333333',
                   '#000000']
color_codes = ['background blue', 'gray blue', 'dark gray-blue', 'accent blue', 'link blue', 'dark blue',
              'light green', 'bright green', 'dark green', 'orange', 'purple', 'maroon',
              'cardinal red', 'line gray', 'light gray', 'gray', 'med gray', 'text gray',
              'black']
cdict = dict(zip(color_codes, [mcolors.to_rgba(c) for c in stanford_palette]))

mcolors.get_named_colors_mapping().update(cdict)

# Project-local reference point. pickle.load can execute arbitrary code, so
# this file must only ever come from a trusted source.
with open('/REDACTED/fairness/code/rf/data/status_quo.pickle','rb') as f:
    status_quo_point = pickle.load(f)

def get_avg(data, varb):
    """Return the ``base_weight``-weighted mean of column *varb* over audited rows.

    Rows contribute with weight ``base_weight * aud_ind``, so rows with
    ``aud_ind == 0`` are excluded.
    """
    audited_weight = data['base_weight'] * data['aud_ind']
    numerator = (audited_weight * data[varb]).sum()
    return numerator / audited_weight.sum()

def get_annual(data, varb, weights, divisor=5):
    """Return the year-weighted total of *varb* divided by *divisor*.

    Parameters
    ----------
    data : DataFrame with a ``study_year`` column and column *varb*.
    varb : name of the column to total per study year.
    weights : mapping ``study_year -> multiplier``; only years present here
        contribute. A year missing from *data* raises ``KeyError``.
    divisor : value the weighted total is divided by (default 5, the value
        previously hard-coded; presumably the number of study years — confirm).
    """
    totals_by_year = data.groupby('study_year')[varb].sum().to_dict()
    total = sum(weights[year] * totals_by_year[year] for year in weights)
    return total / divisor

def calc_wtd_median(data):
    """Return the ``base_weight``-weighted median of ``chg_in_tax_owed_pv``.

    Rows are sorted by ``chg_in_tax_owed_pv``; the median is the value of the
    first row whose running weight reaches half of the total weight. If that
    running weight lands exactly on the halfway point, the result is the
    midpoint of that row's value and the next row's value, so equal weights
    reproduce the ordinary median.

    Returns NaN for empty input.
    """
    dat = data[['chg_in_tax_owed_pv', 'base_weight']].sort_values('chg_in_tax_owed_pv')
    if dat.empty:
        return np.nan

    values = dat['chg_in_tax_owed_pv'].to_numpy()
    cumsum = dat['base_weight'].cumsum().to_numpy()
    cutoff = dat['base_weight'].sum() / 2

    # The previous version took the "below" neighbour with .iloc[0] on an
    # increasing cumsum, which always picked the smallest value.
    reached = np.flatnonzero(cumsum >= cutoff)
    if reached.size == 0:
        return np.nan
    pos = reached[0]

    median = values[pos]
    # Exact tie at the halfway point: average with the next value up.
    if math.isclose(cumsum[pos], cutoff) and pos + 1 < len(values):
        median = (median + values[pos + 1]) / 2
    return median

def get_share_nc(df, black=True):
    """Return the weighted share of noncompliant rows within one race group.

    Rows are weighted by ``base_weight`` times the group membership
    probability: ``predicted_prob_black`` when ``black == True``, otherwise
    ``predicted_prob_nonblack``.
    """
    prob_col = 'predicted_prob_black' if black == True else 'predicted_prob_nonblack'
    prob = df[prob_col]
    numerator = (df.noncomp * prob * df.base_weight).sum()
    denominator = (prob * df.base_weight).sum()
    return numerator / denominator

def get_share_c(df, black=True):
    """Return the weighted share of compliant rows within one race group.

    Computed as one minus the weighted noncompliance share, where rows are
    weighted by ``base_weight`` times ``predicted_prob_black`` (when
    ``black == True``) or ``predicted_prob_nonblack`` (otherwise).
    """
    prob_col = 'predicted_prob_black' if black == True else 'predicted_prob_nonblack'
    prob = df[prob_col]
    noncomp_share = (df.noncomp * prob * df.base_weight).sum() / (prob * df.base_weight).sum()
    return 1 - noncomp_share

def get_fp(df, black=True):
    """Return the weighted audit rate among compliant rows (``noncomp == 0``).

    Weights are ``base_weight`` times the group probability column
    (``predicted_prob_black`` if ``black == True``, else
    ``predicted_prob_nonblack``).
    """
    prob = df['predicted_prob_black'] if black == True else df['predicted_prob_nonblack']
    compliant = 1 - df.noncomp
    audited_compliant = (df.base_weight * df.aud_ind * compliant * prob).sum()
    all_compliant = (df.base_weight * compliant * prob).sum()
    return audited_compliant / all_compliant

def get_s(df, black=True):
    """Return the weighted audit rate among noncompliant rows (``noncomp == 1``).

    Weights are ``base_weight`` times the group probability column
    (``predicted_prob_black`` if ``black == True``, else
    ``predicted_prob_nonblack``).
    """
    prob = df['predicted_prob_black'] if black == True else df['predicted_prob_nonblack']
    audited_noncomp = (df.base_weight * df.aud_ind * df.noncomp * prob).sum()
    all_noncomp = (df.base_weight * df.noncomp * prob).sum()
    return audited_noncomp / all_noncomp

def get_disp(df):
    """Return the weighted audit rate for the black group minus the non-black group.

    Each group's rate weights rows by ``base_weight`` times that group's
    membership probability column.
    """
    def audit_rate(prob):
        return (df.aud_ind * df.base_weight * prob).sum() / (prob * df.base_weight).sum()

    return audit_rate(df.predicted_prob_black) - audit_rate(df.predicted_prob_nonblack)

def get_model_diff_stats(clsDat, regDat):
    '''
    Compare bootstrapped disparity estimates from two models.

    For each estimator column (``reg_fair`` = linear, ``chen_fair`` =
    probabilistic), take the per-run difference ``clsDat - regDat``, print its
    mean and the p-value of a one-sample t-test against 0, and return
    ``(meanLinDif, linPVal, meanProbDif, probPVal)``.
    '''
    results = []
    for col, label in (('reg_fair', 'linear'), ('chen_fair', 'probabilistic')):
        diff = clsDat[col] - regDat[col]
        mean_diff = diff.mean()
        p_val = stats.ttest_1samp(diff, popmean=0)[1]

        print(label + ' estimator mean difference: ' + str(mean_diff))
        print(label + ' estimator p value: ' + str(p_val))

        results.extend((mean_diff, p_val))

    return tuple(results)

def _allocate_audits(data, sort_col, aud):
    """Return a copy of *data* with ``aud_ind`` set for the top-ranked rows up to budget *aud*.

    Rows are ranked by *sort_col* (ties broken by ``chg_in_tax_owed``), both
    descending. Rows whose running ``base_weight`` total is below *aud* get
    ``aud_ind = 1``. Any budget left over is spent on a fractional copy of the
    next unselected row: the copy carries the leftover weight with
    ``aud_ind = 1``, and the original row's weight is reduced by the same amount.
    *data* itself is not modified.
    """
    ranked = data.sort_values(by=[sort_col, 'chg_in_tax_owed'], ascending=False)
    ranked['cumsum'] = ranked['base_weight'].cumsum()
    ranked['aud_ind'] = (ranked['cumsum'] < aud).astype(int)
    actual = ranked.loc[ranked['aud_ind'] == 1, 'base_weight'].sum()

    not_selected = ranked['aud_ind'].to_numpy() != 1
    if actual < aud and not_selected.any():
        # First row (in ranked order) that was not selected for audit.
        row = ranked.index[not_selected.argmax()]
        dif = aud - actual
        dupe = ranked.loc[row].copy()
        dupe['aud_ind'] = 1
        dupe['base_weight'] = dif
        dupe['wgt_chg'] = dupe['base_weight'] * dupe['chg_in_tax_owed_pv']
        ranked.loc[row, 'base_weight'] = ranked.loc[row, 'base_weight'] - dif
        # DataFrame.append was removed in pandas 2.0; concat a one-row frame instead.
        ranked = pd.concat([ranked, dupe.to_frame().T.infer_objects()])
    return ranked


def _plot_disparity(values, path, title, color='purple', ses=None, err_color='dark purple'):
    """Scatter one disparity value per model on a fresh figure and save it to *path*.

    When *ses* is given, error bars of 1.96 * se are drawn. Y tick labels are
    shown in percentage points (values divided by 0.01). The figure is closed
    after saving.
    """
    labels = ['Model ' + str(k + 1) for k in range(len(values))]
    x_axis = np.arange(len(labels))

    fig, ax = plt.subplots(1, 1, figsize=(10, 8))
    ax.scatter(x_axis - 0.2, values, color=color)
    if ses is not None:
        ax.errorbar(x_axis - 0.2, values, yerr=[se * 1.96 for se in ses],
                    capsize=10, fmt='o', color=err_color)
    ax.set_xticks(x_axis)
    ax.set_xticklabels(labels)
    ax.axhline(y=0, color='gray')
    scale_y = 0.01
    ax.yaxis.set_major_formatter(
        ticker.FuncFormatter(lambda y, pos: '{0:g}'.format(y / scale_y)))
    ax.set_xlabel("Model")
    ax.set_ylabel("Disparity (linear, percentage points)")
    ax.set_title(title)

    fig.savefig(path)
    plt.close(fig)


def get_models_disp(datapath='/REDACTED/fairness/code/rf/data/',
                    outdir='/REDACTED/audit_descr_stats/data/final_paper_figs/',
                    eitc=True,
                    dep_database=True,
                    errbars=False):

    '''
    Load the five trained regressor/classifier model pairs and plot their
    audit disparity.

    For each fold i in 0..4, both models score the full
    ``clean_rf_data_plus_dep_database.csv`` data (restricted to activity codes
    270/271 when ``eitc == True``). Rows are then selected for audit up to a
    budget of ceil(1.45% of total ``base_weight``), and disparity is estimated
    with ``unc.regEstimate``. Scatter plots are written to *outdir*:
    ``reg_disp_by_model.png``, ``cls_disp_by_model.png`` and
    ``diff_in_disp_by_model.png`` (error bars only when ``errbars == True``),
    plus the ``res_*`` variants with error bars always drawn.

    *dep_database* is currently unused: the ``*_plus_dep_database_*`` model
    files are always loaded.
    '''

    with open('/REDACTED/fairness/code/rf/config/data-config.yaml', 'r') as stream:
        out = yaml.safe_load(stream)
    print('config file loaded')

    data = pd.read_csv(datapath + 'clean_rf_data_plus_dep_database.csv')

    if eitc == True:
        # .copy() so the new prediction columns are written to an owned frame.
        data = data.loc[(data.activity_code == 270) | (data.activity_code == 271)].copy()

    feature_vars = [x for x in out['features_plus_dep_database_str'] if x in data.columns]
    features = data[feature_vars]

    # The audit budget depends only on total weight, which never changes here.
    popsize = data['base_weight'].sum()
    aud = math.ceil(popsize*(1.45/100))

    reg_disps, reg_ses = [], []
    cls_disps, cls_ses = [], []

    for i in range(5):
        reg_mdl = load(datapath + 'EITC_NCMP_RF_Reg_plus_dep_database_train_set_' + str(i) + '.joblib')
        cls_mdl = load(datapath + 'EITC_NCMP_RF_Class_100_plus_dep_database_train_set_' + str(i) + '.joblib')

        # Index-aligned assignment: `data` is never re-sorted in place, so each
        # prediction lands on the row it was computed for. (The old in-place
        # sort misaligned positional numpy assignments after the first fold.)
        data['reg_sort_var'] = pd.Series(reg_mdl.predict(features), index=features.index)
        data['cls_sort_var'] = pd.Series(cls_mdl.predict_proba(features)[:, 1], index=features.index)

        for sort_col, disps, ses in (('reg_sort_var', reg_disps, reg_ses),
                                     ('cls_sort_var', cls_disps, cls_ses)):
            audited = _allocate_audits(data, sort_col, aud)
            coef, se = unc.regEstimate(audited, pbvarb='predicted_prob_black', outcome='aud_ind', wvarb='base_weight')
            disps.append(coef)
            ses.append(se)

    diff = [r - c for r, c in zip(reg_disps, cls_disps)]

    # unconstrained plots
    _plot_disparity(reg_disps, outdir + 'reg_disp_by_model.png', "Regressor Disparity by Model",
                    ses=reg_ses if errbars == True else None)
    _plot_disparity(cls_disps, outdir + 'cls_disp_by_model.png', "Classifier Disparity by Model",
                    ses=cls_ses if errbars == True else None)
    _plot_disparity(diff, outdir + 'diff_in_disp_by_model.png',
                    "Difference in Disparity Across Models (reg - cls)", color='dark purple')

    # NOTE(review): the former "constrained" pass repeated the unconstrained
    # computation verbatim (no constraint was applied), so the same estimates
    # are reused here. TODO: apply the intended constraint before plotting.
    _plot_disparity(reg_disps, outdir + 'res_reg_disp_by_model.png', "Regressor Disparity by Model",
                    ses=reg_ses)
    _plot_disparity(cls_disps, outdir + 'res_cls_disp_by_model.png', "Classifier Disparity by Model",
                    ses=cls_ses)
    _plot_disparity(diff, outdir + 'res_diff_in_disp_by_model.png',
                    "Difference in Disparity Across Models (reg - cls)", color='dark purple')

    return None

def get_plot_dict_new(acsource='return',
                 eitc=True,
                 activity_code=None,
                 rand_seed=50,
                 bootstrap=False,
                 bootstrap_iters=100,
                 write_bootstrap_fair=True,
                 model='reg', ## options are: 'reg', 'cls', 'oracle'
                 retrain=True,
                 thresh=None,## options are: 100, 1000, 'mean', 'median'
                 dep_database=True,
                 datapath='/REDACTED/fairness/code/rf/data/',
                 outdir='/REDACTED/data/modeled_refactor_temp/',
                 one_fold_params=False,
                 fold_for_params=None,
                 write_test=False,
                 net_rev_version = False,
                 BIFSG_mean = False,
                 config_to_use = "original",
                 mdl_wgt = None,
                 superres_270 = False,
                 superres_271 = False,
                 superres_v2 = False,
                 unwins_costs = False,
                 costs_v2 = False):

    # set seed
    np.random.seed(rand_seed)

    # communicate what model we're running
    if model=='reg' or model=='oracle':
        print('startaxpayer_idg ' + str(model) + ', EITC ' + str(eitc) + ', dep_database ' + str(dep_database))
    elif model=='cls':
        print('startaxpayer_idg ' + str(model) + ', thresh=' + str(thresh) + ', EITC ' + str(eitc) + ', dep_database ' + str(dep_database))

    ## read in config
    os.chdir('/REDACTED/fairness/code/rf/config')
    if config_to_use == "original":    
        stream = open('data-config.yaml', 'r')
    elif config_to_use == "BIFSG":    
        stream = open('data-config-BIFSG.yaml', 'r')
    elif config_to_use == "minus_6":    
        stream = open('data-config-minus_6.yaml', 'r')
    elif config_to_use == "minus_5":
        stream = open('data-config-minus_5.yaml', 'r')
    elif config_to_use == "irs_dep_risk_score_miss":
        stream = open('data-config-irs_dep_risk_score_miss.yaml', 'r')
    elif config_to_use == "top_40_features_total_underreportaxpayer_idg":
        stream = open('data-config-40-total-underreportaxpayer_idg.yaml', 'r')
    elif config_to_use == "ac_271":
        stream = open('data-config-271.yaml', 'r')
    print(os.path.basename(stream.name))
    out = yaml.safe_load(stream)
    print('config file loaded')

    ## read in audit allocations
    if acsource=='individual':
        print('using AC shares calculated from invidual level')
        #with open('/REDACTED/data/metadata/r_by_ac_norm.pickle', 'rb') as f:
            #audit_alloc = pickle.load(f)    
        if eitc==True:
            with open('/REDACTED/data/metadata/audit_shares_within_eitc.pickle', 'rb') as f:
                audit_shares = pickle.load(f)
        elif eitc==False:
            with open('/REDACTED/data/metadata/audit_shares.pickle', 'rb') as f:
                audit_shares = pickle.load(f)
            with open('/REDACTED/data/metadata/audit_shares_eitc.pickle', 'rb') as f:
                audit_shares_eitc = pickle.load(f)
                audit_alloc_eitc = audit_shares_eitc
    elif acsource=='return':
        print('using AC audit shares calculated from return level')
        #with open('/REDACTED/data/metadata/r_by_ac_norm_ret.pickle', 'rb') as f:
            #audit_alloc = pickle.load(f)    
        if eitc==True:
            with open('/REDACTED/data/metadata/audit_shares_within_eitc_ret.pickle', 'rb') as f:
                audit_shares = pickle.load(f)
        elif eitc==False:
            with open('/REDACTED/data/metadata/audit_shares_ret.pickle', 'rb') as f:
                audit_shares = pickle.load(f)
            with open('/REDACTED/data/metadata/audit_shares_eitc_ret.pickle', 'rb') as f:
                audit_shares_eitc = pickle.load(f)
                audit_alloc_eitc=audit_shares_eitc

    ## set array of audit budgets using config file
    budgets = out['budget']

    ## generate dictionary of full pop weights for each study year
    full_data = pd.read_csv(datapath + 'clean_rf_data.csv')
    if eitc == False:
        full_pop_wgts = full_data.groupby('study_year')['base_weight'].sum().to_dict()
        print('full pop weights are:')
        print(full_pop_wgts)
    elif eitc == True and superres_v2 == False:
        full_pop_wgts = full_data.loc[(full_data.activity_code == 270) | (full_data.activity_code == 271)].groupby('study_year')['base_weight'].sum().to_dict()
        print('full pop weights are:')
        print(full_pop_wgts)
    elif eitc == True and superres_v2 == True:
        full_pop_wgts = full_data.loc[full_data.activity_code == 270].groupby('study_year')['base_weight'].sum().to_dict()
        print('full pop weights are:')
        print(full_pop_wgts)

    ## read in data
    train_dict = {}
    test_dict = {}
    yr_wgts_dict = {}

    for i in range(5):
        train_dict['train_' + str(i)] = pd.read_csv(datapath + 'train_data_' + str(i) + '_eitc_' + str(eitc) + '_dep_database_' + str(dep_database) + '.csv')
        test_dict['test_' + str(i)] = pd.read_csv(datapath + 'test_data_' + str(i) + '_eitc_' + str(eitc) + '_dep_database_' + str(dep_database) + '.csv')

        if eitc==True and superres_v2 == False:
            train_dict['train_' + str(i)] = train_dict['train_' + str(i)].loc[(train_dict['train_' + str(i)].activity_code==270) | (train_dict['train_' + str(i)].activity_code==271)]
            test_dict['test_' + str(i)] = test_dict['test_' + str(i)].loc[(test_dict['test_' + str(i)].activity_code==270) | (test_dict['test_' + str(i)].activity_code==271)]
        if eitc==True and superres_v2 == True:
            train_dict['train_' + str(i)] = train_dict['train_' + str(i)].loc[train_dict['train_' + str(i)].activity_code==270]
            test_dict['test_' + str(i)] = test_dict['test_' + str(i)].loc[test_dict['test_' + str(i)].activity_code==270]

        if activity_code is not None:
            train_dict['train_' + str(i)] = train_dict['train_' + str(i)].loc[(train_dict['train_' + str(i)].activity_code==activity_code)]
            test_dict['test_' + str(i)] = test_dict['test_' + str(i)].loc[(test_dict['test_' + str(i)].activity_code==activity_code)]

        ## add cost data to df
        if unwins_costs == False:
            if costs_v2==False:
                test_dict['test_' + str(i)] = getCostsACOnly(test_dict['test_' + str(i)],acvarb='activity_code',median=False,wins=True)
            else:
                test_dict['test_' + str(i)] = getCostsACOnly(test_dict['test_' + str(i)],acvarb='activity_code',median=False,wins=True,v2=True)
        else:
            if costs_v2==False:
                test_dict['test_' + str(i)] = getCostsACOnly(test_dict['test_' + str(i)],acvarb='activity_code',median=False,wins=False)
            else:
                test_dict['test_' + str(i)] = getCostsACOnly(test_dict['test_' + str(i)],acvarb='activity_code',median=False,wins=False,v2=True)

        ## define no change varbs
        test_dict['test_' + str(i)]['nochg_0'] = [1 if x==0 else 0 for x in test_dict['test_' + str(i)].chg_in_tax_owed_pv]
        test_dict['test_' + str(i)]['nochg_lt10'] = [1 if (x>=0) & (x<=10) else 0 for x in test_dict['test_' + str(i)].chg_in_tax_owed_pv]
        test_dict['test_' + str(i)]['nochg_lt100'] = [1 if (x>=0) & (x<=100) else 0 for x in test_dict['test_' + str(i)].chg_in_tax_owed_pv]
        test_dict['test_' + str(i)]['nochg_100abs'] = [1 if (x>=-100) & (x<=100) else 0 for x in test_dict['test_' + str(i)].chg_in_tax_owed_pv]
        test_dict['test_' + str(i)]['negchg'] = [1 if x<0 else 0 for x in test_dict['test_' + str(i)].chg_in_tax_owed_pv]

    if retrain == True and model != 'oracle':
        datapath1 = datapath
        for i in range(5):
            datapath = datapath1
            # read in hyperparams
            if eitc == True and dep_database == True and model == 'reg' and one_fold_params == False:
                params = pd.read_pickle(datapath + 'eitc_dep_database_reg_best_params_train_fold_' + str(i) + '.pickle')
            elif eitc == True and dep_database == True and model == 'cls' and one_fold_params == False:
                params = pd.read_pickle(datapath + 'eitc_dep_database_cls_' + str(thresh) + '_best_params_train_fold_' + str(i) + '.pickle')
            elif eitc == True and dep_database == True and model == 'reg' and one_fold_params == True:
                params = pd.read_pickle(datapath + 'eitc_dep_database_reg_best_params_train_fold_' + str(fold_for_params) + '.pickle')
            elif eitc == True and dep_database == True and model == 'cls' and one_fold_params == True:
                params = pd.read_pickle(datapath + 'eitc_dep_database_cls_' + str(thresh) + '_best_params_train_fold_' + str(fold_for_params) + '.pickle')
            else:
                print('code is not built out to handle this model type yet.')

            # define features, labels
            if dep_database == True:
                feature_vars = [x for x in out['features_plus_dep_database_str'] if x in train_dict['train_' + str(i)].columns]
                features = train_dict['train_' + str(i)][feature_vars]
            else:
                feature_vars = [x for x in out['features_str'] if x in train_dict['train_' + str(i)].columns]
                features = train_dict['train_' + str(i)][feature_vars]

            if model == 'reg':
                labels = train_dict['train_' + str(i)]['chg_in_tax_owed_pv']
                mdl = RandomForestRegressor(**params)
            elif model == 'cls':
                train_dict['train_' + str(i)]['tc_' + str(thresh)] = [1 if x>=thresh else 0 for x in train_dict['train_' + str(i)].chg_in_tax_owed_pv]
                labels = train_dict['train_' + str(i)]['tc_' + str(thresh)]
                mdl = RandomForestClassifier(**params)

            # fit model
            if mdl_wgt is not None:
                weights = train_dict['train_' + str(i)][mdl_wgt]
                mdl.fit(features, labels, sample_weight = weights)
            else:
                mdl.fit(features, labels)
            if config_to_use != 'original':
                datapath = outdir
            if eitc==True and dep_database==True and model == 'cls':
                dump(mdl, datapath + 'EITC_NCMP_RF_Class_' + str(thresh) + '_plus_dep_database_train_set_' + str(i) + '.joblib')
            elif eitc==True and dep_database==True and model == 'reg':
                dump(mdl, datapath + 'EITC_NCMP_RF_Reg_plus_dep_database_train_set_' + str(i) + '.joblib')

            print('model fitted')

    # generataxpayer_idg predictions to use for audit allocation
    if config_to_use != 'original':
        datapath = outdir
    for i in range(5):
        if BIFSG_mean == True:
            rf_eitc_df = full_data[(full_data['activity_code']==270) | (full_data['activity_code']==271)]
            mean_BIFSG = (rf_eitc_df['predicted_prob_black']*rf_eitc_df['base_weight']).sum()/(rf_eitc_df['base_weight']).sum()
            test_dict['test_' + str(i)]['predicted_prob_black_original'] = test_dict['test_' + str(i)]['predicted_prob_black']
            test_dict['test_' + str(i)]['predicted_prob_black'] = mean_BIFSG
        # define features
        if dep_database == True:
            feature_vars = [x for x in out['features_plus_dep_database_str'] if x in test_dict['test_' + str(i)].columns]
            features = test_dict['test_' + str(i)][feature_vars]
        else:
            feature_vars = [x for x in out['features_str'] if x in test_dict['test_' + str(i)].columns]
            features = test_dict['test_' + str(i)][feature_vars]
        # gen predictions - if net_rev = true, then subtract cost from adjustment
        if eitc==True and dep_database==True and model == 'cls':
            mdl = load(datapath + 'EITC_NCMP_RF_Class_' + str(thresh) + '_plus_dep_database_train_set_' + str(i) + '.joblib')
            test_dict['test_' + str(i)]['sort_var'] = mdl.predict_proba(features)[:, 1]
        elif eitc==True and dep_database==True and model == 'reg' and net_rev_version == False:
            mdl = load(datapath + 'EITC_NCMP_RF_Reg_plus_dep_database_train_set_' + str(i) + '.joblib')
            test_dict['test_' + str(i)]['sort_var'] = mdl.predict(features)
        elif eitc==True and dep_database==True and model == 'reg' and net_rev_version == True:
            mdl = load(datapath + 'EITC_NCMP_RF_Reg_plus_dep_database_train_set_' + str(i) + '.joblib')
            test_dict['test_' + str(i)]['sort_var'] = mdl.predict(features)
            test_dict['test_' + str(i)]['sort_var'] = test_dict['test_' + str(i)]['sort_var'] - test_dict['test_' + str(i)]['exp_cost']
        elif model == 'oracle' and net_rev_version == False:
            test_dict['test_' + str(i)]['sort_var'] = test_dict['test_' + str(i)]['chg_in_tax_owed_pv']
        elif model == 'oracle' and net_rev_version == True:
            test_dict['test_' + str(i)]['sort_var'] = test_dict['test_' + str(i)]['chg_in_tax_owed_pv']
            test_dict['test_' + str(i)]['sort_var'] = test_dict['test_' + str(i)]['sort_var'] - test_dict['test_' + str(i)]['exp_cost']            
        # sort each test set by sort var and gen cumsum varb

        test_dict['test_' + str(i)].sort_values(by=['sort_var', 'chg_in_tax_owed'], inplace=True, ascending=False)
        test_dict['test_' + str(i)]['cumsum'] = test_dict['test_' + str(i)].base_weight.cumsum()
        test_dict['test_' + str(i)]['wgt_chg'] = [x*y for x, y in zip(test_dict['test_' + str(i)].base_weight, test_dict['test_' + str(i)].chg_in_tax_owed_pv)]

        # save predictions
        #if eitc == True and dep_database == True and net_rev_version == False:
        #    dftest = pd.DataFrame(test_dict['test_'+str(i)])
        #    dftest.to_csv(datapath+"test_fold_"+str(i)+"_"+str(model)+"_plus_dep_database_pred_new_bifsg.csv")
        #elif eitc == True and dep_database == True and model != 'cls' and net_rev_version == True:
        #    dftest = pd.DataFrame(test_dict['test_'+str(i)])
        #    dftest.to_csv(datapath+"net_rev/test_fold_"+str(i)+"_"+str(model)+"_plus_dep_database_pred_new_bifsg.csv")        
        if BIFSG_mean == True:
            test_dict['test_' + str(i)]['predicted_prob_black'] = test_dict['test_' + str(i)]['predicted_prob_black_original']
        # gen yr_wgts dictionary for each test fold
        test_wgts = test_dict['test_' + str(i)].groupby('study_year')['base_weight'].sum().to_dict()
        yr_wgts_dict['yr_wgts_' + str(i)] = {x: full_pop_wgts[x]/test_wgts[x] for x in full_pop_wgts.keys()}
        print('year weights for fold ' + str(i) + ' are:')
        print(yr_wgts_dict['yr_wgts_' + str(i)])

    full_data = pd.concat(test_dict.values(), ignore_index=True)
    full_data.sort_values(by=['sort_var', 'chg_in_tax_owed'], inplace=True, ascending=False)
    full_data['cumsum'] = full_data.base_weight.cumsum()
    full_data['wgt_chg'] = [x*y for x, y in zip(full_data.base_weight, full_data.chg_in_tax_owed_pv)]

    #####################################################################
    ##### UNRESTRICTED MODELS ###########################################
    #####################################################################

    outcome_dict = {}
    five_fold = {'f0_revenue': [np.nan]*len(budgets),
                 'f1_revenue': [np.nan]*len(budgets),
                 'f2_revenue': [np.nan]*len(budgets),
                 'f3_revenue': [np.nan]*len(budgets),
                 'f4_revenue': [np.nan]*len(budgets),
                 'f0_reg_fair': [np.nan]*len(budgets),
                 'f1_reg_fair': [np.nan]*len(budgets),
                 'f2_reg_fair': [np.nan]*len(budgets),
                 'f3_reg_fair': [np.nan]*len(budgets),
                 'f4_reg_fair': [np.nan]*len(budgets),
                 'f0_chen_fair': [np.nan]*len(budgets),
                 'f1_chen_fair': [np.nan]*len(budgets),
                 'f2_chen_fair': [np.nan]*len(budgets),
                 'f3_chen_fair': [np.nan]*len(budgets),
                 'f4_chen_fair': [np.nan]*len(budgets)
}

    for varb in out['5_fold_outcomes']:
        outcome_dict[varb + '_dict'] = {}
        for i in budgets:
            if 'share' in varb:
                pass
            else:
                outcome_dict[varb + '_dict'][str(i) + 'p'] = [np.nan]*5

    for varb in out['bootstrap_outcomes']:
        outcome_dict[varb + '_dict'] = {}
        for i in budgets:
            if 'share' in varb:
                pass
            else:
                outcome_dict[varb + '_dict'][str(i) + 'p'] = [np.nan]*bootstrap_iters

    plot_dict = {}

    for i in budgets:

        # gen plot dict
        for varb in out['full_data_outcomes']:
            if 'share' in varb:
                pass
            else:
                plot_dict[str(i) + 'p_mean_' + varb + '_unres_' + str(model)] = []
                plot_dict[str(i) + 'p_std_' + varb + '_unres_' + str(model)] = []

        for varb in out['bootstrap_outcomes']:
            if 'share' in varb:
                pass
            else:
                plot_dict[str(i) + 'p_mean_' + varb + '_unres_' + str(model)] = []
                plot_dict[str(i) + 'p_std_' + varb + '_unres_' + str(model)] = []

        for varb in out['5_fold_outcomes']:
            if 'share' in varb:
                pass
            else:
                plot_dict[str(i) + 'p_mean_' + varb + '_unres_' + str(model)] = []
                plot_dict[str(i) + 'p_std_' + varb + '_unres_' + str(model)] = []

        # starting with full data
        # determine audit budget for full test set
        popsize = full_data['base_weight'].sum()
        aud = math.ceil(popsize*(i/100))

        # assign initial audit inds to full test set
        copy = full_data.copy(deep=True)
        copy['aud_ind'] = [1 if x < aud else 0 for x in copy['cumsum']]
        actual = copy.loc[copy.aud_ind==1, 'base_weight'].sum()

        # if there is audit budget remaining, use it on a portion of the
        # next observation up for audit
        if actual/aud < 1:
            row = copy.index[(copy.aud_ind.values != 1).argmax()]
            dif = aud - actual
            dupe = copy.loc[row].copy()
            dupe['aud_ind'] = 1
            dupe['base_weight'] = dif
            dupe['wgt_chg'] = dupe['base_weight']*dupe['chg_in_tax_owed_pv']
            copy.loc[row, 'base_weight'] = copy.loc[row, 'base_weight'] - dif
            copy = copy.append(dupe)

        # write here: if i = 1.45, output copy, dynamically named
        if i == 1.45:
            copy.to_csv(outdir + 'unres_' + str(model) + '_selected_taxpayer_ids.csv')

        # write full test set fairness measures and standard errors to plot dict

        plot_dict[str(i) + 'p_mean_full_chen_fair_unres_' + str(model)] = unc.chenEstimate(copy, pbvarb='predicted_prob_black',outcome='aud_ind', wvarb='base_weight')
        coef, se = unc.regEstimate(copy, pbvarb='predicted_prob_black',outcome='aud_ind', wvarb='base_weight')
        plot_dict[str(i) + 'p_mean_full_reg_fair_unres_' + str(model)] = coef

        seChen, seReg = unc.getSEs(copy, pbvarb='predicted_prob_black', outcome='aud_ind', wvarb='base_weight')

        plot_dict[str(i) + 'p_std_full_chen_fair_unres_' + str(model)] = seChen
        plot_dict[str(i) + 'p_std_full_reg_fair_unres_' + str(model)] = seReg

        if model == 'cls' and i == 1.45:
            copy['noncomp'] = [1 if x > 100 else 0 for x in copy.chg_in_tax_owed_pv]
            copy['predicted_prob_nonblack'] = 1 - copy['predicted_prob_black']

            fp_cls_b = get_fp(copy, black=True)
            fp_cls_nb = get_fp(copy, black=False)
            s_cls_b = get_s(copy, black=True)
            s_cls_nb = get_s(copy, black=False)
            cb = get_share_c(copy, black=True)
            cnb = get_share_c(copy, black=False)
            ncb = get_share_nc(copy, black=True)
            ncnb = get_share_nc(copy, black=False)
            disp_cls = get_disp(copy)

            table = pd.DataFrame(columns = ['rownames', 'Black', 'NonBlack', 'Disparity'])
            table.rownames = ['False-positive Rate', 'Sensitivity', 'Share Compliant', 'Observed Disparity']
            table.Black = [fp_cls_b, s_cls_b, cb, np.nan]
            table.NonBlack = [fp_cls_nb, s_cls_nb, cnb, np.nan]
            table.Disparity = [cb*(fp_cls_b - fp_cls_nb)*100, (1-cb)*(s_cls_b - s_cls_nb)*100, (cnb - cb)*(s_cls_nb - fp_cls_nb)*100, disp_cls*100]
            table[['Black', 'NonBlack', 'Disparity']] = table[['Black', 'NonBlack', 'Disparity']].round(6)

            with open(outdir + 'cls_unres_decomp.tex', 'w') as tf:
                tf.write(table.to_latex(index=False))

        # next, 5-fold estimates

        for j in range(5):
            copy = test_dict['test_' + str(j)].copy(deep=True)
            popsize = copy['base_weight'].sum()
            aud = math.ceil(popsize*(i/100))

            copy['aud_ind'] = [1 if x < aud else 0 for x in copy['cumsum']]
            actual = copy.loc[copy.aud_ind==1, 'base_weight'].sum()

            # if there is audit budget remaining, use it on a portion of the
            # next observation up for audit
            if actual/aud < 1:
                row = copy.index[(copy.aud_ind.values != 1).argmax()]
                dif = aud - actual
                dupe = copy.loc[row].copy()
                dupe['aud_ind'] = 1
                dupe['base_weight'] = dif
                dupe['wgt_chg'] = dupe['base_weight']*dupe['chg_in_tax_owed_pv']
                copy.loc[row, 'base_weight'] = copy.loc[row, 'base_weight'] - dif
                copy = copy.append(dupe)

            # define some varbs
            copy['total_ref_adj'] = copy['aud_ind'] * copy['base_weight'] * copy['ref_cred_amt_dif_pv']
            copy['revenue'] = copy['aud_ind']*copy['wgt_chg']
            copy['cost'] = copy['aud_ind']*copy['base_weight']*copy['exp_cost']
            copy['net_rev'] = copy['revenue'] - copy['cost']

            for varb in out['annualized_varb_list']:
                outcome_dict['5_fold_' + varb + '_dict'][str(i) + 'p'][j] = get_annual(copy, varb, yr_wgts_dict['yr_wgts_' + str(j)])

            for varb in out['avg_varb_list']:
                outcome_dict['5_fold_' + varb + '_dict'][str(i) + 'p'][j] = get_avg(copy, varb)

            # Compute fairness using weighted chen + reg estimators
            outcome_dict['5_fold_chen_fair_dict'][str(i) + 'p'][j] = unc.chenEstimate(copy, pbvarb='predicted_prob_black', outcome='aud_ind', wvarb='base_weight')
            coef, se = unc.regEstimate(copy, pbvarb='predicted_prob_black', outcome='aud_ind', wvarb='base_weight')
            outcome_dict['5_fold_reg_fair_dict'][str(i) + 'p'][j] = coef

            five_fold['f' + str(j) + '_revenue'][budgets.index(i)] = get_annual(copy, 'revenue', yr_wgts_dict['yr_wgts_' + str(j)])
            five_fold['f' + str(j) + '_chen_fair'][budgets.index(i)] = unc.chenEstimate(copy, pbvarb='predicted_prob_black', outcome='aud_ind', wvarb='base_weight')
            five_fold['f' + str(j) + '_reg_fair'][budgets.index(i)] = coef

        for varb in out['5_fold_outcomes']:
            if 'share' in varb:
                pass
            else:
                plot_dict[str(i) + 'p_mean_' + varb + '_unres_' + str(model)].append('{:.4f}'.format(mean(outcome_dict[varb + '_dict'][str(i) + 'p'])))
                plot_dict[str(i) + 'p_std_' + varb + '_unres_' + str(model)].append('{:.4f}'.format(np.std(outcome_dict[varb + '_dict'][str(i) + 'p'])/np.sqrt(5)))

        # last, bootstrap estimates
        if bootstrap == True:
            for j in range(bootstrap_iters):
                fold = full_data.sample(frac=1, replace=True, random_state=j)
                fold.sort_values(by=['sort_var', 'chg_in_tax_owed'], inplace=True, ascending=False)
                # added next line on 4/27/23 to fix ValueError thrown by 'copy.loc[row, 'base_weight'] = copy.loc[row, 'base_weight'] - dif' below
                fold.reset_index(inplace = True, drop = True)
                fold['cumsum'] = fold.base_weight.cumsum()
                fold['wgt_chg'] = [x*y for x, y in zip(fold.base_weight, fold.chg_in_tax_owed_pv)]

                fold_wgts = fold.groupby('study_year')['base_weight'].sum().to_dict()
                yr_wgts = {x: full_pop_wgts[x]/fold_wgts[x] for x in full_pop_wgts.keys()}

                # determine audit budget
                popsize = fold['base_weight'].sum()
                aud = math.ceil(popsize*(i/100))

                # assign initial audit inds
                copy = fold.copy(deep=True)
                copy['aud_ind'] = [1 if x < aud else 0 for x in copy['cumsum']]
                actual = copy.loc[copy.aud_ind==1, 'base_weight'].sum()

                # if there is audit budget remaining, use it on a portion of the
                # next observation up for audit
                if actual/aud < 1:
                    row = copy.index[(copy.aud_ind.values != 1).argmax()]
                    dif = aud - actual
                    dupe = copy.loc[row].copy()
                    dupe['aud_ind'] = 1
                    dupe['base_weight'] = dif
                    dupe['wgt_chg'] = dupe['base_weight']*dupe['chg_in_tax_owed_pv']
                    print("hey")
                    print(copy.loc[row, 'base_weight'])
                    copy.loc[row, 'base_weight'] = copy.loc[row, 'base_weight'] - dif
                    copy = copy.append(dupe)

                # define some varbs
                copy['total_ref_adj'] = copy['aud_ind'] * copy['base_weight'] * copy['ref_cred_amt_dif_pv']
                copy['revenue'] = copy['aud_ind']*copy['wgt_chg']
                copy['cost'] = copy['aud_ind']*copy['base_weight']*copy['exp_cost']
                copy['net_rev'] = copy['revenue'] - copy['cost']

                for varb in out['annualized_varb_list']:
                    outcome_dict['bootstrap_' + varb + '_dict'][str(i) + 'p'][j] = get_annual(copy, varb, yr_wgts)

                for varb in out['avg_varb_list']:
                    outcome_dict['bootstrap_' + varb + '_dict'][str(i) + 'p'][j] = get_avg(copy, varb)

                # Compute fairness using weighted chen + reg estimators
                outcome_dict['bootstrap_chen_fair_dict'][str(i) + 'p'][j] = unc.chenEstimate(copy, pbvarb='predicted_prob_black', outcome='aud_ind', wvarb='base_weight')
                coef, se = unc.regEstimate(copy, pbvarb='predicted_prob_black', outcome='aud_ind', wvarb='base_weight')
                outcome_dict['bootstrap_reg_fair_dict'][str(i) + 'p'][j] = coef

            for varb in out['bootstrap_outcomes']:
                if 'share' in varb:
                    pass
                else:
                    plot_dict[str(i) + 'p_mean_' + varb + '_unres_' + str(model)].append('{:.4f}'.format(mean(outcome_dict[varb + '_dict'][str(i) + 'p'])))
                    plot_dict[str(i) + 'p_std_' + varb + '_unres_' + str(model)].append('{:.4f}'.format(np.std(outcome_dict[varb + '_dict'][str(i) + 'p'])))

            if i == 1.45 and bootstrap == True and write_bootstrap_fair == True:
                bootstrap_dat = pd.DataFrame(columns=['reg_fair', 'chen_fair'])
                bootstrap_dat.chen_fair = outcome_dict['bootstrap_chen_fair_dict'][str(i) + 'p']
                bootstrap_dat.reg_fair = outcome_dict['bootstrap_reg_fair_dict'][str(i) + 'p']
                bootstrap_dat.to_csv(outdir + 'unres_' + str(model) + '_bootstrap_fair_obs.csv')

    with open(outdir + 'unres_' + str(model) + '_fold_by_fold_rev_and_fair.pickle', 'wb') as handle:
        pickle.dump(five_fold, handle, protocol=pickle.HIGHEST_PROTOCOL)

    if model=='oracle' or model=='reg':
        if eitc==False and activity_code is None:
            with open(outdir + 'unres_' + str(model) + '_output_new_bifsg.pickle', 'wb') as handle:
                pickle.dump(plot_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
        elif eitc==True and dep_database==False and activity_code is None:
            with open(outdir + 'eitc_unres_' + str(model) + '_output_new_bifsg.pickle', 'wb') as handle:
                pickle.dump(plot_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
        elif eitc==True and dep_database==True and activity_code is None:
            with open(outdir + 'eitc_unres_' + str(model) + '_plus_dep_database_output_new_bifsg.pickle', 'wb') as handle:
                pickle.dump(plot_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
        elif activity_code is not None:
            with open(outdir + 'unres_' + str(activity_code) + '_' + str(model) + '_dep_database_' + str(dep_database) + '_eitc_' + str(eitc) + '_output_new_bifsg.pickle', 'wb') as handle:
                pickle.dump(plot_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)

    elif model=='cls':
        if eitc==False and activity_code is None:
            with open(outdir + 'unres_' + str(model) + '_' + str(thresh) + '_output_new_bifsg.pickle', 'wb') as handle:
                pickle.dump(plot_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
        elif eitc==True and dep_database==False and activity_code is None:
            with open(outdir + 'eitc_unres_' + str(model) + '_' + str(thresh) + '.pickle', 'wb') as handle:
                pickle.dump(plot_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
        elif eitc==True and dep_database==True and activity_code is None:
            with open(outdir + 'eitc_unres_' + str(model) + '_' + str(thresh) + '_plus_dep_database_output_new_bifsg.pickle', 'wb') as handle:
                pickle.dump(plot_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
        elif activity_code is not None:
            with open(outdir + 'unres_' + str(activity_code) + '_' + str(model) + '_dep_database_' + str(dep_database) + '_eitc_' + str(eitc) + '_output_new_bifsg.pickle', 'wb') as handle:
                pickle.dump(plot_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)

    if activity_code is not None:
        return plot_dict

    ##################################################################
    ######## Restricted models #######################################
    ##################################################################

    outcome_dict = {}
    five_fold = {'f0_revenue': [np.nan]*len(budgets),
                 'f1_revenue': [np.nan]*len(budgets),
                 'f2_revenue': [np.nan]*len(budgets),
                 'f3_revenue': [np.nan]*len(budgets),
                 'f4_revenue': [np.nan]*len(budgets),
                 'f0_reg_fair': [np.nan]*len(budgets),
                 'f1_reg_fair': [np.nan]*len(budgets),
                 'f2_reg_fair': [np.nan]*len(budgets),
                 'f3_reg_fair': [np.nan]*len(budgets),
                 'f4_reg_fair': [np.nan]*len(budgets),
                 'f0_chen_fair': [np.nan]*len(budgets),
                 'f1_chen_fair': [np.nan]*len(budgets),
                 'f2_chen_fair': [np.nan]*len(budgets),
                 'f3_chen_fair': [np.nan]*len(budgets),
                 'f4_chen_fair': [np.nan]*len(budgets)
}

    for varb in out['5_fold_outcomes']:
        outcome_dict[varb + '_dict'] = {}
        for i in budgets:
            if 'share' in varb:
                pass
            else:
                outcome_dict[varb + '_dict'][str(i) + 'p'] = [np.nan]*5

    for varb in out['bootstrap_outcomes']:
        outcome_dict[varb + '_dict'] = {}
        for i in budgets:
            if 'share' in varb:
                pass
            else:
                outcome_dict[varb + '_dict'][str(i) + 'p'] = [np.nan]*bootstrap_iters

    plot_dict = {}
    for i in budgets:
        # gen plot dict
        for varb in out['full_data_outcomes']:
            if 'share' in varb:
                pass
            else:
                plot_dict[str(i) + 'p_mean_' + varb + '_res_' + str(model)] = []
                plot_dict[str(i) + 'p_std_' + varb + '_res_' + str(model)] = []

        for varb in out['bootstrap_outcomes']:
            if 'share' in varb:
                pass
            else:
                plot_dict[str(i) + 'p_mean_' + varb + '_res_' + str(model)] = []
                plot_dict[str(i) + 'p_std_' + varb + '_res_' + str(model)] = []

        for varb in out['5_fold_outcomes']:
            if 'share' in varb:
                pass
            else:
                plot_dict[str(i) + 'p_mean_' + varb + '_res_' + str(model)] = []
                plot_dict[str(i) + 'p_std_' + varb + '_res_' + str(model)] = []

        # starting with full data

        # determine audit budget for full test set
        popsize = full_data['base_weight'].sum()
        aud = math.ceil(popsize*(i/100))

        # create empty dataframe to store results
        col_list = full_data.columns.tolist() + ['aud_ind', 'wgt_chg']
        col_list = list(set(col_list))
        fair_df = pd.DataFrame(columns=col_list)
        if superres_270 == True:
            audit_shares = {70.0: 1}
        if superres_271 == True:
            audit_shares = {71.0: 1}
        for ac in audit_shares:
            # assign initial audit inds
            copy = full_data.loc[full_data.activity_code==(ac+200)].copy(deep=True)
            copy.sort_values(by=['sort_var', 'chg_in_tax_owed'], inplace=True, ascending=False)
            copy['wgt_chg'] = [x*y for x, y in zip(copy.base_weight, copy.chg_in_tax_owed_pv)]
            copy['cumsum'] = copy.base_weight.cumsum()

            # determine audit budget
            popsize = copy['base_weight'].sum()
            #aud = math.ceil(popsize*audit_alloc[ac]*alpha)
            aud = math.ceil(full_data.base_weight.sum()*audit_shares[ac]*(i/100))

            # get initial audit assignment
            copy['aud_ind'] = [1 if x < aud else 0 for x in copy['cumsum']]

            # determine how much of audit budget used so far
            actual = copy.loc[copy.aud_ind==1, 'base_weight'].sum()

            # if there is any audit budget remaining, use last of it on portion of next research_audits observation
            if actual/aud < 1:
                # locate row corresponding to next observation up for audit
                row = copy.index[(copy.aud_ind.values != 1).argmax()]
                # determine remaining audit budget
                dif = aud - actual
                # copy row corresponding to next observation up for audit and spend remaining audit budget on it
                dupe = copy.loc[row].copy()
                dupe['aud_ind'] = 1
                dupe['base_weight'] = dif
                dupe['wgt_chg'] = dupe['base_weight']*dupe['chg_in_tax_owed_pv']
                # adjust weight on non-audited remainder of row
                copy.loc[row, 'base_weight'] = copy.loc[row, 'base_weight'] - dif
                # append audited portion of row to copy
                copy = copy.append(dupe)          

            # append audit assignments from this activity code to the results dataframe for this fold and budget
            fair_df = fair_df.append(copy)
            # output csv of fair_df if 1.45- also change in alt_outcome, dynamic names
            if i == 1.45:
                fair_df.to_csv(outdir + 'res_' + str(model) + '_selected_taxpayer_ids.csv')  

        plot_dict[str(i) + 'p_mean_full_chen_fair_res_' + str(model)] = unc.chenEstimate(fair_df, pbvarb='predicted_prob_black',outcome='aud_ind', wvarb='base_weight')
        fair_df.predicted_prob_black = fair_df.predicted_prob_black.astype(float)
        fair_df['aud_ind']=pd.to_numeric(fair_df['aud_ind'])
        fair_df['base_weight']=pd.to_numeric(fair_df['base_weight'])
        coef, se = unc.regEstimate(fair_df, pbvarb='predicted_prob_black',outcome='aud_ind', wvarb='base_weight')
        plot_dict[str(i) + 'p_mean_full_reg_fair_res_' + str(model)] = coef

        seChen, seReg = unc.getSEs(fair_df, pbvarb='predicted_prob_black', outcome='aud_ind', wvarb='base_weight')

        plot_dict[str(i) + 'p_std_full_chen_fair_res_' + str(model)] = seChen
        plot_dict[str(i) + 'p_std_full_reg_fair_res_' + str(model)] = seReg

        if model == 'cls' and i == 1.45:
            fair_df['noncomp'] = [1 if x > 100 else 0 for x in fair_df.chg_in_tax_owed_pv]
            fair_df['predicted_prob_nonblack'] = 1 - fair_df['predicted_prob_black']

            fp_cls_b = get_fp(fair_df, black=True)
            fp_cls_nb = get_fp(fair_df, black=False)
            s_cls_b = get_s(fair_df, black=True)
            s_cls_nb = get_s(fair_df, black=False)
            cb = get_share_c(fair_df, black=True)
            cnb = get_share_c(fair_df, black=False)
            ncb = get_share_nc(fair_df, black=True)
            ncnb = get_share_nc(fair_df, black=False)
            disp_cls = get_disp(fair_df)

            table = pd.DataFrame(columns = ['rownames', 'Black', 'NonBlack', 'Disparity'])
            table.rownames = ['False-positive Rate', 'Sensitivity', 'Share Compliant', 'Observed Disparity']
            table.Black = [fp_cls_b, s_cls_b, cb, np.nan]
            table.NonBlack = [fp_cls_nb, s_cls_nb, cnb, np.nan]
            table.Disparity = [cb*(fp_cls_b - fp_cls_nb)*100, (1-cb)*(s_cls_b - s_cls_nb)*100, (cnb - cb)*(s_cls_nb - fp_cls_nb)*100, disp_cls*100]
            table[['Black', 'NonBlack', 'Disparity']] = table[['Black', 'NonBlack', 'Disparity']].round(6)

            with open(outdir + 'cls_res_decomp.tex', 'w') as tf:
                tf.write(table.to_latex(index=False))

        # next, 5-fold estimates

        for j in range(5):
            fair_df = pd.DataFrame(columns=col_list)

            for ac in audit_shares:
                copy = test_dict['test_' + str(j)].loc[test_dict['test_' + str(j)].activity_code == (ac + 200)].copy(deep=True)
                copy.sort_values(by=['sort_var', 'chg_in_tax_owed'], inplace=True, ascending=False)
                copy['wgt_chg'] = [x*y for x, y in zip(copy.base_weight, copy.chg_in_tax_owed_pv)]
                copy['cumsum'] = copy.base_weight.cumsum()

                # determine audit budget
                popsize = copy['base_weight'].sum()
                #aud = math.ceil(popsize*audit_alloc[ac]*alpha)
                aud = math.ceil(test_dict['test_' + str(j)].base_weight.sum()*audit_shares[ac]*(i/100))

                # get initial audit assignment
                copy['aud_ind'] = [1 if x < aud else 0 for x in copy['cumsum']]

                # determine how much of audit budget used so far
                actual = copy.loc[copy.aud_ind==1, 'base_weight'].sum()

                # if there is any audit budget remaining, use last of it on portion of next research_audits observation
                if actual/aud < 1:
                    # locate row corresponding to next observation up for audit
                    row = copy.index[(copy.aud_ind.values != 1).argmax()]
                    # determine remaining audit budget
                    dif = aud - actual
                    # copy row corresponding to next observation up for audit and spend remaining audit budget on it
                    dupe = copy.loc[row].copy()
                    dupe['aud_ind'] = 1
                    dupe['base_weight'] = dif
                    dupe['wgt_chg'] = dupe['base_weight']*dupe['chg_in_tax_owed_pv']
                    # adjust weight on non-audited remainder of row
                    copy.loc[row, 'base_weight'] = copy.loc[row, 'base_weight'] - dif
                    # append audited portion of row to copy
                    copy = copy.append(dupe)

                # append audit assignments from this activity code to the results dataframe for this fold and budget
                fair_df = fair_df.append(copy)

            # define some varbs
            fair_df['total_ref_adj'] = fair_df['aud_ind'] * fair_df['base_weight'] * fair_df['ref_cred_amt_dif_pv']
            fair_df['revenue'] = fair_df['aud_ind']*fair_df['wgt_chg']
            fair_df['cost'] = fair_df['aud_ind']*fair_df['base_weight']*fair_df['exp_cost']
            fair_df['net_rev'] = fair_df['revenue'] - fair_df['cost']

            for varb in out['annualized_varb_list']:
                outcome_dict['5_fold_' + varb + '_dict'][str(i) + 'p'][j] = get_annual(fair_df, varb, yr_wgts_dict['yr_wgts_' + str(j)])

            for varb in out['avg_varb_list']:
                outcome_dict['5_fold_' + varb + '_dict'][str(i) + 'p'][j] = get_avg(fair_df, varb)

            # Compute fairness using weighted chen + reg estimators
            outcome_dict['5_fold_chen_fair_dict'][str(i) + 'p'][j] = unc.chenEstimate(fair_df, pbvarb='predicted_prob_black', outcome='aud_ind', wvarb='base_weight')
            fair_df['predicted_prob_black']=pd.to_numeric(fair_df['predicted_prob_black'])
            fair_df['aud_ind']=pd.to_numeric(fair_df['aud_ind'])
            fair_df['base_weight']=pd.to_numeric(fair_df['base_weight'])           
            coef, se = unc.regEstimate(fair_df, pbvarb='predicted_prob_black', outcome='aud_ind', wvarb='base_weight')
            outcome_dict['5_fold_reg_fair_dict'][str(i) + 'p'][j] = coef

            five_fold['f' + str(j) + '_revenue'][budgets.index(i)] = get_annual(fair_df, 'revenue', yr_wgts_dict['yr_wgts_' + str(j)])
            five_fold['f' + str(j) + '_chen_fair'][budgets.index(i)] = unc.chenEstimate(fair_df, pbvarb='predicted_prob_black', outcome='aud_ind', wvarb='base_weight')
            five_fold['f' + str(j) + '_reg_fair'][budgets.index(i)] = coef

        for varb in out['5_fold_outcomes']:
            if 'share' in varb:
                pass
            else:
                plot_dict[str(i) + 'p_mean_' + varb + '_res_' + str(model)].append('{:.4f}'.format(mean(outcome_dict[varb + '_dict'][str(i) + 'p'])))
                plot_dict[str(i) + 'p_std_' + varb + '_res_' + str(model)].append('{:.4f}'.format(np.std(outcome_dict[varb + '_dict'][str(i) + 'p'])/np.sqrt(5)))

        # last, bootstrap estimates
        if bootstrap == True:
            for j in range(bootstrap_iters):
                fold = full_data.sample(frac=1, replace=True, random_state=j)
                fold_wgts = fold.groupby('study_year')['base_weight'].sum().to_dict()
                yr_wgts = {x: full_pop_wgts[x]/fold_wgts[x] for x in full_pop_wgts.keys()}
                fair_df = pd.DataFrame(columns=col_list)

                for ac in audit_shares:
                    copy = fold.loc[fold.activity_code == (ac + 200)].copy(deep=True)
                    copy.sort_values(by=['sort_var', 'chg_in_tax_owed'], inplace=True, ascending=False)
                    # added next line on 4/27/23 in anticipation of ValueError thrown by 'copy.loc[row, 'base_weight'] = copy.loc[row, 'base_weight'] - dif' above
                    copy.reset_index(inplace = True, drop = True)
                    copy['wgt_chg'] = [x*y for x, y in zip(copy.base_weight, copy.chg_in_tax_owed_pv)]
                    copy['cumsum'] = copy.base_weight.cumsum()

                    # determine audit budget
                    popsize = copy['base_weight'].sum()
                    #aud = math.ceil(popsize*audit_alloc[ac]*alpha)
                    aud = math.ceil(fold.base_weight.sum()*audit_shares[ac]*(i/100))

                    # get initial audit assignment
                    copy['aud_ind'] = [1 if x < aud else 0 for x in copy['cumsum']]

                    # determine how much of audit budget used so far
                    actual = copy.loc[copy.aud_ind==1, 'base_weight'].sum()

                    # if there is any audit budget remaining, use last of it on portion of next research_audits observation
                    if actual/aud < 1:
                        # locate row corresponding to next observation up for audit
                        row = copy.index[(copy.aud_ind.values != 1).argmax()]
                        # determine remaining audit budget
                        dif = aud - actual
                        # copy row corresponding to next observation up for audit and spend remaining audit budget on it
                        dupe = copy.loc[row].copy()
                        dupe['aud_ind'] = 1
                        dupe['base_weight'] = dif
                        dupe['wgt_chg'] = dupe['base_weight']*dupe['chg_in_tax_owed_pv']
                        # adjust weight on non-audited remainder of row
                        copy.loc[row, 'base_weight'] = copy.loc[row, 'base_weight'] - dif
                        # append audited portion of row to copy
                        copy = copy.append(dupe)
                 
                    # append audit assignments from this activity code to the results dataframe for this fold and budget
                    fair_df = fair_df.append(copy)

                # define some varbs
                fair_df['total_ref_adj'] = fair_df['aud_ind'] * fair_df['base_weight'] * fair_df['ref_cred_amt_dif_pv']
                fair_df['revenue'] = fair_df.aud_ind*fair_df.wgt_chg
                fair_df['cost'] = fair_df.aud_ind*fair_df.base_weight*fair_df.exp_cost
                fair_df['net_rev'] = fair_df.revenue - fair_df.cost
                fold_wgts = fold.groupby('study_year')['base_weight'].sum().to_dict()
                yr_wgts = {x: full_pop_wgts[x]/fold_wgts[x] for x in full_pop_wgts.keys()}

                for varb in out['annualized_varb_list']:
                    outcome_dict['bootstrap_' + varb + '_dict'][str(i) + 'p'][j] = get_annual(fair_df, varb, yr_wgts)

                for varb in out['avg_varb_list']:
                    outcome_dict['bootstrap_' + varb + '_dict'][str(i) + 'p'][j] = get_avg(fair_df, varb)

                # Compute fairness using weighted chen + reg estimators
                outcome_dict['bootstrap_chen_fair_dict'][str(i) + 'p'][j] = unc.chenEstimate(fair_df, pbvarb='predicted_prob_black', outcome='aud_ind', wvarb='base_weight')
                #fair_df.aud_ind = pd.to_numeric(fair_df.aud_ind)
                fair_df['predicted_prob_black']=pd.to_numeric(fair_df['predicted_prob_black'])
                fair_df['aud_ind']=pd.to_numeric(fair_df['aud_ind'])
                fair_df['base_weight']=pd.to_numeric(fair_df['base_weight'])
                coef, se = unc.regEstimate(fair_df, pbvarb='predicted_prob_black', outcome='aud_ind', wvarb='base_weight')
                outcome_dict['bootstrap_reg_fair_dict'][str(i) + 'p'][j] = coef

            if i == 1.45 and bootstrap == True and write_bootstrap_fair == True:
                bootstrap_dat = pd.DataFrame(columns=['reg_fair', 'chen_fair'])
                bootstrap_dat.chen_fair = outcome_dict['bootstrap_chen_fair_dict'][str(i) + 'p']
                bootstrap_dat.reg_fair = outcome_dict['bootstrap_reg_fair_dict'][str(i) + 'p']
                bootstrap_dat.to_csv(outdir + 'res_' + str(model) + '_bootstrap_fair_obs.csv')
            for varb in out['bootstrap_outcomes']:
                if 'share' in varb:
                    pass
                else:
                    plot_dict[str(i) + 'p_mean_' + varb + '_res_' + str(model)].append('{:.4f}'.format(mean(outcome_dict[varb + '_dict'][str(i) + 'p'])))
                    plot_dict[str(i) + 'p_std_' + varb + '_res_' + str(model)].append('{:.4f}'.format(np.std(outcome_dict[varb + '_dict'][str(i) + 'p'])))

    with open(outdir + 'res_' + str(model) + '_fold_by_fold_rev_and_fair.pickle', 'wb') as handle:
        pickle.dump(five_fold, handle, protocol=pickle.HIGHEST_PROTOCOL)

    if model=='oracle' or model=='reg':
        if eitc==False and activity_code is None:
            with open(outdir + 'res_' + str(model) + '_output_new_bifsg.pickle', 'wb') as handle:
                pickle.dump(plot_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
        elif eitc==True and dep_database==False and activity_code is None:
            with open(outdir + 'eitc_res_' + str(model) + '_output_new_bifsg.pickle', 'wb') as handle:
                pickle.dump(plot_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
        elif eitc==True and dep_database==True and activity_code is None:
            with open(outdir + 'eitc_res_' + str(model) + '_plus_dep_database_output_new_bifsg.pickle', 'wb') as handle:
                pickle.dump(plot_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
        elif activity_code is not None:
            with open(outdir + 'res_' + str(activity_code) + '_' + str(model) + '_dep_database_' + str(dep_database) + '_eitc_' + str(eitc) + '_output_new_bifsg.pickle', 'wb') as handle:
                pickle.dump(plot_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
    elif model=='cls':
        if eitc==False and activity_code is None:
            with open(outdir + 'res_' + str(model) + '_' + str(thresh) + '_output_new_bifsg.pickle', 'wb') as handle:
                pickle.dump(plot_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
        elif eitc==True and dep_database==False and activity_code is None:
            with open(outdir + 'eitc_res_' + str(model) + '_' + str(thresh) + '_output_new_bifsg.pickle', 'wb') as handle:
                pickle.dump(plot_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
        elif eitc==True and dep_database==True and activity_code is None:
            with open(outdir + 'eitc_res_' + str(model) + '_' + str(thresh) + '_plus_dep_database_output_new_bifsg.pickle', 'wb') as handle:
                pickle.dump(plot_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
        elif activity_code is not None:
            with open(outdir + 'res_' + str(activity_code) + '_' + str(model) + '_dep_database_' + str(dep_database) + '_eitc_' + str(eitc) + '_output_new_bifsg.pickle', 'wb') as handle:
                pickle.dump(plot_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
    return plot_dict



